# CPA-Enhancer block
import torch
import torch.nn as nn
from .cpa_arch import CPA_arch


class PromptRestormer(nn.Module):
    """Wraps ``CPA_arch`` with a global residual connection.

    The backbone output is added back onto the input, so the module learns
    a correction on top of the original image rather than a full
    reconstruction.
    """

    def __init__(self, c_in=3, c_out=3, dim=32):
        super().__init__()
        self.prompt_unet_arch = CPA_arch(c_in, c_out, dim)

    def forward(self, x):
        identity = x
        x = self.prompt_unet_arch(x)
        # Global residual: add the input back onto the backbone output.
        return x + identity
    

# Quick smoke test for the CPA block.
if __name__ == '__main__':
    model = PromptRestormer(3, 3)
    dummy = torch.zeros(1, 3, 1280, 1280, dtype=torch.float32)
    with torch.no_grad():
        out = model(dummy)
    print(out.shape)
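

# A minimal one-step training sanity check, added here as a hedged sketch:
# it assumes the enhancer would be optimized with a simple pixel-wise L1 loss
# against a clean target, which may differ from how the full model is
# actually trained, and it assumes CPA_arch accepts 256x256 inputs. The
# tensors `degraded` and `clean` are hypothetical placeholders.
if __name__ == '__main__':
    model = PromptRestormer(3, 3)
    degraded = torch.rand(1, 3, 256, 256)   # stand-in for a degraded image
    clean = torch.rand(1, 3, 256, 256)      # stand-in for its clean target
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

    optimizer.zero_grad()
    restored = model(degraded)
    loss = nn.functional.l1_loss(restored, clean)
    loss.backward()                          # gradients flow through CPA_arch
    optimizer.step()
    print(f'one-step L1 loss: {loss.item():.4f}')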